# 模型相关工具函数
# 包含量化配置等功能
from transformers import BitsAndBytesConfig
import torch


def get_quantization_config(load_in_4bit=True):
    """Build the quantization config to use when loading a model.

    Args:
        load_in_4bit: When truthy, a 4-bit configuration is returned;
            when falsy, no quantization config is produced.

    Returns:
        A ``BitsAndBytesConfig`` with 4-bit loading enabled, double
        quantization turned on, the ``'nf4'`` quant type and a
        ``torch.bfloat16`` compute dtype, or ``None`` when
        ``load_in_4bit`` is falsy.
    """
    # Guard clause: nothing to configure when 4-bit loading is off.
    if not load_in_4bit:
        return None

    # Collect the 4-bit settings in one place before building the config.
    four_bit_options = {
        'load_in_4bit': True,
        'bnb_4bit_use_double_quant': True,
        'bnb_4bit_quant_type': 'nf4',
        'bnb_4bit_compute_dtype': torch.bfloat16,
    }
    return BitsAndBytesConfig(**four_bit_options)